P-Unit Model

1. Motivation and implementation

The P-Unit model is a leaky integrate-and-fire (LIF) model that additionally embeds several aspects of the sensory pathway of weakly electric fish. In the baseline condition, the input to the model is a sinusoid at the frequency of the electric organ discharge (EOD), with its amplitude normalized to one:

\[ S(t) = S_{EOD}(t) = \cos(2\pi f_{EOD} t) \]

P-Units respond to amplitude changes of their EOD carrier. To stimulate the P-Unit model with Gaussian white noise \(\xi(t)\), the baseline is multiplied by the amplitude modulation \(1 + c\,\xi(t)\), where the contrast \(c\) is 10% by default:

\[ S_{am}(t) = S_{EOD}(t) + (S_{EOD}(t) \xi(t) c) \]
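A minimal sketch of how such a stimulus could be built with JAX. The function name `am_stimulus` is hypothetical, and drawing \(\xi(t)\) as independent unit-variance Gaussian samples at every time step (no cutoff frequency) is an assumption; the text above does not say how the noise is band-limited.

```python
import jax
import jax.numpy as jnp


def am_stimulus(key, eodf, duration, dt, contrast=0.1):
    """Return (time, S_am) with S_am = S_EOD + S_EOD * xi * c.

    Assumption: xi is unit-variance Gaussian white noise drawn
    independently at every time step (no band-limiting).
    """
    time = jnp.arange(0.0, duration, dt)
    s_eod = jnp.cos(2 * jnp.pi * eodf * time)  # baseline carrier, amplitude 1
    xi = jax.random.normal(key, time.shape)     # noise amplitude modulation
    s_am = s_eod + s_eod * xi * contrast         # same form as the equation above
    return time, s_am


key = jax.random.PRNGKey(0)
time, s_am = am_stimulus(key, eodf=744.66, duration=1.0, dt=5e-5)
```

Passing `s_am` instead of the baseline `stimulus` to `punit.simulate` in the example should then give the response to the noise modulation.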

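The stimulus then passes through a threshold operation between the receptor cells and the afferent (P-Unit). In the afferent dendrite, the thresholded stimulus is low-pass filtered with the dendritic time constant \(\tau_{d}\). Here \(\lfloor x \rfloor_{0}^{p}\) denotes setting negative values of \(x\) to zero and raising the result to the power \(p\):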

\[ \tau_{d} \frac{d V_{d}}{d t} = -V_{d}+ \lfloor S(t) \rfloor_{0}^{p} \]
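The rectification and the dendritic low-pass filter can be sketched with a forward-Euler step inside `jax.lax.scan`. The helper `dendrite` is hypothetical, and the initial value \(V_d(0) = 0\) and the default \(p = 1\) are assumptions.

```python
import jax
import jax.numpy as jnp


def dendrite(stimulus, dend_tau, dt, p=1.0):
    """Forward-Euler integration of tau_d dV_d/dt = -V_d + floor(S)_0^p.

    Assumptions: V_d starts at 0 and the exponent p defaults to 1
    (plain half-wave rectification).
    """
    drive = jnp.maximum(stimulus, 0.0) ** p  # threshold at zero, then power p

    def step(v_d, s):
        v_d = v_d + dt * (-v_d + s) / dend_tau  # one Euler step of the low-pass
        return v_d, v_d

    _, v_d = jax.lax.scan(step, jnp.zeros((), stimulus.dtype), drive)
    return v_d


dt = 5e-5
t = jnp.arange(20000) * dt                  # 1 s sampled at 20 kHz
carrier = jnp.cos(2 * jnp.pi * 744.66 * t)  # unit-amplitude EOD carrier
v_d = dendrite(carrier, dend_tau=7.742e-4, dt=dt)
```

Because the discrete filter has unit gain at zero frequency, the time average of \(V_d\) approaches the mean of the rectified input, which is \(1/\pi\) for the unit-amplitude carrier.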

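The resulting dendritic voltage \(V_d\) is scaled by the factor \(\alpha\) and serves as the input to the LIF. Another addition to the standard LIF model is an adaptation current \(A\), which is subtracted from the membrane voltage. At each spike, \(A\) is increased by an amount set by `delta_a` (see the example below); between spikes it decays with the time constant \(\tau_{A}\):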

\[ \tau_{A} \frac{d A}{d t} = - A \]
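Between spikes, this equation has the closed-form solution \(A(t) = A(0)\,e^{-t/\tau_A}\). The sketch below integrates it with forward Euler and compares the result against that solution. `adaptation_decay` is a hypothetical helper, treating `a_zero` as \(A(0)\) is an assumption, and spike-triggered increments are left out.

```python
import jax
import jax.numpy as jnp


def adaptation_decay(a_zero, tau_a, dt, n_steps):
    """Forward-Euler integration of tau_A dA/dt = -A without spikes."""

    def step(a, _):
        a = a - dt * a / tau_a  # exponential decay toward zero
        return a, a

    init = jnp.asarray(a_zero, jnp.float32)
    _, trace = jax.lax.scan(step, init, None, length=n_steps)
    return trace


dt, tau_a, a_zero = 5e-5, 0.1022, 9.45
trace = adaptation_decay(a_zero, tau_a, dt, n_steps=4000)  # 0.2 s
t = dt * jnp.arange(1, 4001)                                # time after each step
exact = a_zero * jnp.exp(-t / tau_a)                        # closed-form decay
```

With \(\tau_A \approx 0.1\) s, the adaptation current loses about 63% of its value within roughly 100 ms when no spikes occur.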

Lastly, there is a refractory period: after the membrane voltage \(V_m(t)\) crosses the threshold \(\theta = 1\), the integration of \(V_m(t)\) is paused. The fixed input bias \(\mu\) and the noise term \(\sqrt{2D}\xi(t)\) are the same as in the standard LIF. This results in the following differential equation:

\[ \tau_{m} \frac{d V_{m}}{d t} = - V_{m} + \mu + \alpha V_{d} - A + \sqrt{2D}\xi(t) \]
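The equations above can be combined into a single Euler-Maruyama loop. This is a sketch, not the implementation behind `punit.simulate`: the function `simulate_sketch` is hypothetical, and the following details are assumptions not stated in the text: \(p = 1\), a reset of \(V_m\) to `v_base` after each spike, a spike-triggered jump of \(A\) by `delta_a / tau_a`, and the correspondence of \(\mu\) with `v_offset` and of \(\alpha\) with `input_scaling`. The noise intensity is passed directly as \(D\) (`noise_d`), since how it relates to `noise_strength` is not stated here.

```python
import jax
import jax.numpy as jnp


def simulate_sketch(key, stimulus, dt, mem_tau, dend_tau, tau_a, delta_a,
                    input_scaling, v_offset, noise_d, ref_period,
                    threshold=1.0, v_base=0.0, v_zero=0.0, a_zero=0.0):
    """Hypothetical Euler-Maruyama sketch of the P-Unit equations.

    Returns (spikes, v_mem), one entry per time step.
    """
    noise = jax.random.normal(key, stimulus.shape)
    drive = jnp.maximum(stimulus, 0.0)       # rectification, p = 1 (assumed)
    ref_steps = int(round(ref_period / dt))  # refractory period in samples

    def step(carry, inputs):
        v_d, v_m, a, ref_left = carry
        s, xi = inputs
        # dendrite: tau_d dV_d/dt = -V_d + floor(S)_0
        v_d = v_d + dt * (-v_d + s) / dend_tau
        # membrane: tau_m dV_m/dt = -V_m + mu + alpha V_d - A + sqrt(2D) xi
        dv = dt * (-v_m + v_offset + input_scaling * v_d - a) / mem_tau
        dv = dv + jnp.sqrt(2.0 * noise_d * dt) * xi / mem_tau
        v_m = jnp.where(ref_left > 0, v_m, v_m + dv)  # paused while refractory
        # adaptation decays between spikes
        a = a - dt * a / tau_a
        spike = v_m > threshold
        # assumed reset rule: V_m -> v_base, A jumps by delta_a / tau_a
        v_m = jnp.where(spike, v_base, v_m)
        a = jnp.where(spike, a + delta_a / tau_a, a)
        ref_left = jnp.where(spike, ref_steps, jnp.maximum(ref_left - 1, 0))
        return (v_d, v_m, a, ref_left), (spike.astype(stimulus.dtype), v_m)

    init = (jnp.zeros((), stimulus.dtype),
            jnp.asarray(v_zero, stimulus.dtype),
            jnp.asarray(a_zero, stimulus.dtype),
            jnp.zeros((), jnp.int32))
    _, (spikes, v_mem) = jax.lax.scan(step, init, (drive, noise))
    return spikes, v_mem


dt = 5e-5
t = jnp.arange(20000) * dt
spikes, v_mem = simulate_sketch(
    jax.random.PRNGKey(42), jnp.cos(2 * jnp.pi * 744.66 * t), dt,
    mem_tau=1.726e-3, dend_tau=7.742e-4, tau_a=0.1022, delta_a=0.0605,
    input_scaling=31.36, v_offset=-0.3906,
    noise_d=1e-5,  # illustrative value, not taken from the fitted parameters
    ref_period=1.027e-3,
)
```

Keeping all state in the `scan` carry lets JAX compile the whole loop once, and the refractory period is tracked as a countdown in samples rather than as a spike time.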

3. Example

Here is a minimal example to get you started.

import jax
import jax.numpy as jnp

import jaxon.models.punit as punit
from jaxon.dsp.kernels import gauss_kernel
from jaxon.dsp.rate import spike_rate

duration = 10
# parameters for P-Unit model
punit_params = {
    "cell": "2010-11-08-al-invivo-1",
    "EODf": 744.66,
    "a_zero": 9.450855200303527,
    "delta_a": 0.0604984400793618,
    "dend_tau": 0.0007742334994649,
    "input_scaling": 31.363843698084207,
    "mem_tau": 0.0017257848281706,
    "noise_strength": 0.0124091008125932,
    "ref_period": 0.0010273077926126,
    "deltat": 5e-05,
    "tau_a": 0.1022386553157565,
    "threshold": 1,
    "v_base": 0,
    "v_offset": -0.390625,
    "v_zero": 0,
}
cell = punit_params.pop("cell")
eodf = punit_params.pop("EODf")

params = punit.PUnitParams(**punit_params)
# parameters for the Gaussian rate kernel
sigma = 0.007
ktime = 4
fs = 1 / params.deltat

# random key for the noise of the LIF model
key = jax.random.PRNGKey(42)
time = jnp.arange(0, duration, 1 / fs)

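# baseline stimulus: the EOD carrier with unit amplitude
stimulus = jnp.cos(2 * jnp.pi * eodf * time)
binary_spikes, vmem = punit.simulate(key, stimulus, params)

# Gaussian kernel for estimating the firing rate from the spike train
kernel = gauss_kernel(sigma, 1 / fs, ktime)

rate = spike_rate(binary_spikes, kernel)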
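The idea behind the kernel rate estimate can be illustrated without jaxon. The sketch below assumes, without checking the library, that `gauss_kernel(sigma, dt, ktime)` samples a Gaussian with standard deviation `sigma` over \(\pm\) `ktime` \(\cdot\) `sigma` and normalizes it to unit area, and that `spike_rate` convolves the binary spike train with it. `gaussian_kernel_sketch` and `rate_sketch` are hypothetical stand-ins, not jaxon's API; the `demo_` prefix keeps them from overwriting the variables of the example.

```python
import jax.numpy as jnp


def gaussian_kernel_sketch(sigma, dt, ktime):
    """Gaussian sampled at dt on [-ktime*sigma, ktime*sigma] with unit area."""
    t = jnp.arange(-ktime * sigma, ktime * sigma + dt, dt)
    k = jnp.exp(-0.5 * (t / sigma) ** 2)
    return k / (k.sum() * dt)  # unit area, so the rate comes out in Hz


def rate_sketch(binary_spikes, kernel):
    """Firing rate in Hz: every spike contributes one unit-area kernel."""
    return jnp.convolve(binary_spikes, kernel, mode="same")


demo_dt = 5e-5
demo_spikes = jnp.zeros(20000).at[10000].set(1.0)  # a single spike at t = 0.5 s
demo_kernel = gaussian_kernel_sketch(sigma=0.007, dt=demo_dt, ktime=4)
demo_rate = rate_sketch(demo_spikes, demo_kernel)  # peak about 1 / (sqrt(2 pi) sigma)
```

With unit-area normalization, integrating the rate over time returns the spike count, here one.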

We can now plot the results of the simulation.

import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_scatter(x=time, y=vmem, mode="lines", name="V", secondary_y=False)


fig.add_scatter(
    x=time[binary_spikes.astype(bool)],
    y=vmem[binary_spikes.astype(bool)] + params.threshold,
    mode="markers",
    marker_size=10,
    marker_color="red",
    marker_symbol="arrow-down",
    name="Spikes",
    secondary_y=False,
)
fig.add_scatter(
    x=time, y=rate, name="Rate [Hz]", secondary_y=True, marker_color="magenta", line_width=4
)

fig.update_layout(xaxis_title="Time [s]", yaxis_title="Voltage [a.u.]")
fig.update_yaxes(title_text="Rate [Hz]", secondary_y=True)
fig.update_xaxes(range=[0, 0.2])
fig.show()